{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '6'\n",
    "from interaction_effects import utils\n",
    "utils.set_up_environment()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:absl:Overwrite dataset info from restored data version.\n",
      "INFO:absl:Reusing dataset glue (/homes/gws/psturm/tensorflow_datasets/glue/mrpc/0.0.2)\n",
      "INFO:absl:Constructing tf.data.Dataset for split None, from /homes/gws/psturm/tensorflow_datasets/glue/mrpc/0.0.2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train for 115 steps, validate for 7 steps\n",
      "Epoch 1/2\n",
      "115/115 [==============================] - 65s 562ms/step - loss: 0.5497 - accuracy: 0.7266 - val_loss: 0.3931 - val_accuracy: 0.8554\n",
      "Epoch 2/2\n",
      "115/115 [==============================] - 42s 367ms/step - loss: 0.2915 - accuracy: 0.8833 - val_loss: 0.4309 - val_accuracy: 0.8260\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_datasets\n",
    "from transformers import *\n",
    "\n",
    "# Load dataset, tokenizer, model from pretrained model/vocabulary\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n",
    "model = TFBertForSequenceClassification.from_pretrained('bert-base-cased')\n",
    "data = tensorflow_datasets.load('glue/mrpc')\n",
    "\n",
    "# Prepare dataset for GLUE as a tf.data.Dataset instance\n",
    "train_dataset = glue_convert_examples_to_features(data['train'], tokenizer, max_length=128, task='mrpc')\n",
    "valid_dataset = glue_convert_examples_to_features(data['validation'], tokenizer, max_length=128, task='mrpc')\n",
    "train_dataset = train_dataset.shuffle(100).batch(32).repeat(2)\n",
    "valid_dataset = valid_dataset.batch(64)\n",
    "\n",
    "# Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule \n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)\n",
    "loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
    "metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')\n",
    "model.compile(optimizer=optimizer, loss=loss, metrics=[metric])\n",
    "\n",
    "# Train and evaluate using tf.keras.Model.fit()\n",
    "history = model.fit(train_dataset, epochs=2, steps_per_epoch=115,\n",
    "                    validation_data=valid_dataset, validation_steps=7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'idx': <tf.Tensor: id=57251, shape=(), dtype=int32, numpy=201>, 'label': <tf.Tensor: id=57252, shape=(), dtype=int64, numpy=1>, 'sentence1': <tf.Tensor: id=57253, shape=(), dtype=string, numpy=b'Tibco has used the Rendezvous name since 1994 for several of its technology products , according to the Palo Alto , California company .'>, 'sentence2': <tf.Tensor: id=57254, shape=(), dtype=string, numpy=b'Tibco has used the Rendezvous name since 1994 for several of its technology products , it said .'>}\n"
     ]
    }
   ],
   "source": [
    "for item in data['train'].take(1):\n",
    "    print(item)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "({'input_ids': <tf.Tensor: id=57263, shape=(32, 128), dtype=int32, numpy=\n",
      "array([[  101,  2892,  1209, ...,     0,     0,     0],\n",
      "       [  101,  1130,   170, ...,     0,     0,     0],\n",
      "       [  101,   138, 19959, ...,     0,     0,     0],\n",
      "       ...,\n",
      "       [  101,   138,  2370, ...,     0,     0,     0],\n",
      "       [  101,  1124,  7005, ...,     0,     0,     0],\n",
      "       [  101,  1153,  4567, ...,     0,     0,     0]], dtype=int32)>, 'attention_mask': <tf.Tensor: id=57262, shape=(32, 128), dtype=int32, numpy=\n",
      "array([[1, 1, 1, ..., 0, 0, 0],\n",
      "       [1, 1, 1, ..., 0, 0, 0],\n",
      "       [1, 1, 1, ..., 0, 0, 0],\n",
      "       ...,\n",
      "       [1, 1, 1, ..., 0, 0, 0],\n",
      "       [1, 1, 1, ..., 0, 0, 0],\n",
      "       [1, 1, 1, ..., 0, 0, 0]], dtype=int32)>, 'token_type_ids': <tf.Tensor: id=57264, shape=(32, 128), dtype=int32, numpy=\n",
      "array([[0, 0, 0, ..., 0, 0, 0],\n",
      "       [0, 0, 0, ..., 0, 0, 0],\n",
      "       [0, 0, 0, ..., 0, 0, 0],\n",
      "       ...,\n",
      "       [0, 0, 0, ..., 0, 0, 0],\n",
      "       [0, 0, 0, ..., 0, 0, 0],\n",
      "       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)>}, <tf.Tensor: id=57265, shape=(32,), dtype=int64, numpy=\n",
      "array([0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1,\n",
      "       1, 1, 0, 1, 1, 1, 1, 1, 1, 0])>)\n"
     ]
    }
   ],
   "source": [
    "for item in train_dataset.take(1):\n",
    "    print(item)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
